# datasets/h5ad_generic.py
from __future__ import annotations

from typing import Optional, Dict, Any

import numpy as np

from .base import DatasetSpec
from .registry import register
from .transforms import default_preprocess

def _require_scanpy():
    """Fail fast with an actionable message when scanpy cannot be imported.

    Any exception raised while importing scanpy is converted into a
    RuntimeError carrying install instructions; the original exception is
    kept as the cause. Returns None when the import succeeds.
    """
    install_hint = (
        "This loader requires 'scanpy' (and anndata). Install with:\n"
        "    pip install scanpy anndata\n"
    )
    try:
        # Imported only to probe availability; the name is not used here.
        import scanpy  # noqa: F401
    except Exception as import_error:
        raise RuntimeError(install_hint) from import_error

@register("h5ad")
def load_h5ad(
    cache_dir: Optional[str] = None,
    *,
    path: str,
    X_key: Optional[str] = None,     # e.g., "X_pca" or None for .X
    label_key: Optional[str] = None, # e.g., "cell_type"
    batch_key: Optional[str] = None, # e.g., "batch"
    preprocess: bool = True,
    pca_n: Optional[int] = 50,
    random_state: int = 0,
) -> DatasetSpec:
    """
    Generic H5AD loader. Reads .X or an .obsm slot (e.g., 'X_pca').
    Returns DatasetSpec with labels/batch if available.

    Parameters
    ----------
    cache_dir:
        Not used by this loader (presumably accepted for signature parity
        with other registered loaders -- confirm against the registry).
    path:
        Location of the .h5ad file, passed to ``scanpy.read_h5ad``.
    X_key:
        Name of the ``.obsm`` entry to use as the data matrix. When None,
        ``.X`` is used instead.
    label_key, batch_key:
        Names of ``.obs`` columns to return as ``labels`` / ``batch``.
        A key that is None, empty, or absent from ``.obs`` is silently
        ignored and the corresponding field is left as None.
    preprocess:
        When True, the matrix is passed through ``default_preprocess`` and
        the info dict it returns is merged into ``meta``.
    pca_n, random_state:
        Forwarded to ``default_preprocess``; unused when ``preprocess`` is
        False.

    Returns
    -------
    DatasetSpec
        Named "h5ad". ``X`` is the float32 matrix (or the output of
        ``default_preprocess``). ``meta`` holds "obs_keys", "var_names" and
        "path", plus any keys contributed by ``default_preprocess``.

    Raises
    ------
    RuntimeError
        If scanpy cannot be imported.
    KeyError
        If ``X_key`` is given but is not present in ``.obsm``.
    ValueError
        If ``X_key`` is None and the file has no ``.X`` matrix.
    """
    _require_scanpy()
    import scanpy as sc

    ad = sc.read_h5ad(path)

    # Select the raw matrix first so densification below has a single path.
    if X_key is None:
        raw = ad.X
        if raw is None:
            # Without this guard the failure surfaces later as an opaque
            # TypeError from the float32 cast.
            raise ValueError(
                f"{path!r} has no .X matrix; pass X_key to read an .obsm slot instead."
            )
    else:
        try:
            raw = ad.obsm[X_key]
        except KeyError as err:
            raise KeyError(
                f"X_key {X_key!r} not found in .obsm of {path!r}; "
                f"available keys: {list(ad.obsm.keys())}"
            ) from err

    # Sparse matrices expose .toarray(); everything else goes through numpy.
    dense = raw.toarray() if hasattr(raw, "toarray") else np.asarray(raw)
    X = dense.astype(np.float32, copy=False)

    labels = None
    batch = None
    if label_key and label_key in ad.obs:
        labels = ad.obs[label_key].to_numpy()
    if batch_key and batch_key in ad.obs:
        batch = ad.obs[batch_key].to_numpy()

    meta: Dict[str, Any] = {
        "obs_keys": list(ad.obs.columns),
        "var_names": ad.var_names.to_list() if ad.var_names is not None else None,
        "path": path,
    }

    if preprocess:
        X, info = default_preprocess(X, pca_n=pca_n, random_state=random_state)
        # Keys from the preprocessing info take precedence on collision.
        meta.update(info)
    return DatasetSpec(name="h5ad", X=X, labels=labels, batch=batch, meta=meta)
